% INFDS_TRAIN_NN  Demonstrate how the ReBEL toolkit is used to train a neural network
%                 in efficient "low-level" mode.
%
%   Direct (InferenceDS level) description of a parameter estimation inference data
%   structure needed to train a 2 layer MLP on the standard XOR classification problem
%   using a Sigma-Point Kalman Filter (ukf, cdkf, srcdkf or srukf). Only the minimum
%   amount of needed InferenceDS fields are defined.
%
%   --- NOTE  :  This file is needed by 'train_nn.m'. See discussion in 'train_nn.m'.
%
%   Copyright (c) Oregon Health & Science University (2006)
%
%   This file is part of the ReBEL Toolkit. The ReBEL Toolkit is available free for
%   academic use only (see included license file) and can be obtained from
%   http://choosh.csee.ogi.edu/rebel/.  Businesses wishing to obtain a copy of the
%   software should contact rebel@csee.ogi.edu for commercial licensing information.
%
%   See LICENSE (which should be part of the main toolkit distribution) for more
%   detail.

%=============================================================================================

function InferenceDS = infds_train_nn
% INFDS_TRAIN_NN  Build a minimal InferenceDS structure for SPKF-based parameter
%   estimation (training) of a 2-4-1 MLP on the XOR problem. Only the fields
%   required by the inference algorithms are set; the state/observation/innovation
%   callbacks are the local functions defined below in this file.

    %--- Setup InferenceDS fields needed by inference algorithms

    InferenceDS.type    = 'InferenceDS';          % data structure type identifier tag
    InferenceDS.inftype = 'parameter';            % this InferenceDS data structure will be used for parameter estimation

    InferenceDS.statedim  = (2*4+4 + 4*1+1);      % number of free parameters in the 2-4-1 neural network:
                                                  % 2*4 input-to-hidden weights + 4 hidden biases
                                                  % + 4*1 hidden-to-output weights + 1 output bias = 17
    InferenceDS.obsdim    = 1;                    % network has a single output
    InferenceDS.U1dim     = 0;                    % no exogenous input to the state transition function
    InferenceDS.U2dim     = 2;                    % observation-function input = the 2 network inputs
    InferenceDS.Vdim      = InferenceDS.statedim; % process noise has same dimension as the parameter vector
    InferenceDS.Ndim      = InferenceDS.obsdim;   % observation noise has same dimension as the observation
    InferenceDS.ffun      = @ffun;                % state transition (random-walk) function, defined below
    InferenceDS.hfun      = @hfun;                % state observation (NN forward pass) function, defined below
    InferenceDS.innovation = @innovation;         % innovation (residual) function, defined below


    %--- Store some extra problem specific info in InferenceDS to speed up later calculation

    InferenceDS.nodes  = [2 4 1];       % simple 2 layer ReBEL MLP neural net with 2 inputs, 4 hidden units and 1 output unit
    InferenceDS.olType = 'tanh';        % sigmoidal (hyperbolic tangent) output unit activation... works well for classification
                                        % problems.

%============================================================================================
% Generic State transition function for parameter estimation. Basically a random walk driven
% by artificial process noise (this speeds up convergence). Remember to adapt (anneal) the
% process noise covariance by setting 'pNoiseDS.adaptMethod' to something useful in the main
% calling script.
%
%  Input
%          InfDS       :     stripped down InferenceDS datastructure as defined above
%          state       :     current state of system (in this case the parameters become the
%                            new state variable
%          V           :     process noise vector (all SPKFs need this)
%          U1          :     exogenous input to state transition function (not needed for
%                            parameter estimation, but we must comply with the interface
%                            expected by all estimation algorithms).
%
%  Output
%          new_state   :     system state at next time instant
%

function new_state = ffun(InfDS, state, V, U1)
% FFUN  Generic state transition function for parameter estimation: a random
%   walk on the parameter vector driven by artificial process noise.
%
%   Input
%     InfDS     : stripped down InferenceDS data structure (unused here, but
%                 required by the common estimator interface)
%     state     : current state (the network parameter vector)
%     V         : process noise vector; may be empty, in which case the
%                 parameters pass through unchanged
%     U1        : exogenous input (unused for parameter estimation)
%
%   Output
%     new_state : system state at next time instant

    % NOTE: the original implementation pre-assigned new_state = state before
    % the branch; that assignment was dead code (both branches overwrite it)
    % and has been removed.
    if isempty(V)
        new_state = state;          % no process noise supplied: parameters unchanged
    else
        new_state = state + V;      % random-walk step driven by process noise
    end

%============================================================================================
% Problem specific state observation function. This is where the actual 'parameterised' function
% mapping takes place as a nonlinear observation on the state vector (which is the parameters
% of this mapping function). This can easily be adapted to ANY functional form (i.e. Netlab
% neural networks, Mathworks NN toolbox neural nets, etc. etc.)
%
%  Input
%          InfDS       :     stripped down InferenceDS datastructure as defined above
%          state       :     current state of system (in this case the parameters become the
%                            new state variable
%          N           :     observation noise vector (all SPKFs need this)
%          U2          :     exogenous input to state observation function. This is where you
%                            pass in the original clean inputs to the neural network.
%
%  Output
%          observ      :     output generated by neural network for current state (parameters) and
%                            current input (U2)
%

function observ = hfun(InfDS, state, N, U2)
% HFUN  Problem specific state observation function: evaluates the MLP whose
%   parameters are given by each column of STATE on the corresponding input
%   column of U2, optionally adding measurement noise N.
%
%   Input
%     InfDS  : stripped down InferenceDS data structure (provides obsdim,
%              nodes and olType)
%     state  : parameter vectors, one per column (block mode)
%     N      : observation noise; may be empty
%     U2     : clean network inputs, one column per parameter vector
%
%   Output
%     observ : network output for every (parameter, input) column pair

    % All SPKF based algorithms require block-mode operation, i.e. the
    % function must handle several sigma-point columns at once.
    numSigmaPoints = size(state, 2);

    % Preallocate the output buffer up front.
    observ = zeros(InfDS.obsdim, numSigmaPoints);

    % Forward-propagate each parameter column through the network. 'mlpff'
    % unpacks the flat parameter vector internally; unpacking it in-place
    % here would be faster but less readable.
    for col = 1:numSigmaPoints
        observ(:, col) = mlpff(InfDS.olType, InfDS.nodes, U2(:, col), state(:, col));
    end

    % Additive measurement noise, required by all SPKF based algorithms
    % whenever a noise realization is supplied.
    if ~isempty(N)
        observ = observ + N;
    end


%======================================================================================================
function innov = innovation(InferenceDS, obs, observ)
% INNOVATION  Compute the innovation (residual) signal: the difference
%   between an actual 'real world' observation OBS and the predicted system
%   observation OBSERV produced by HFUN.
%
%   Input
%     InferenceDS : stripped down InferenceDS data structure (unused here,
%                   but required by the common estimator interface)
%     obs         : actual observation
%     observ      : predicted observation (output of HFUN)
%
%   Output
%     innov       : innovation signal, obs - observ

    innov = obs - observ;
